#include <gtest/gtest.h>
#include "aicpu_engine/engine/aicpu_engine.h"

#include "util/util.h"
#include "config/config_file.h"
#include "stub.h"

#include "ge/ge_api_types.h"

using namespace aicpu;
using namespace ge;
using namespace std;

// Checks that engine Initialize() succeeds with SOC_VERSION set to "Ascend910".
// Also checks three things:
//   - the config file exposes the expected kernel and graph-optimizer lists;
//   - the "aicpu_ascend_kernel" ops kernel info store is registered;
//   - that store initializes successfully with the same options.
TEST(AicpuKernelInfo, Initialize_SUCCESS)
{
    map<string, string> options;
    options[SOC_VERSION] = "Ascend910";
    ASSERT_EQ(Initialize(options), SUCCESS);

    string kernelConfig;
    ASSERT_TRUE(ConfigFile::GetInstance().GetValue("DNN_VM_AICPU_ASCENDOpsKernel", kernelConfig));
    ASSERT_EQ(kernelConfig, "CUSTAICPUKernel,AICPUKernel");

    string optimizerConfig;
    ASSERT_TRUE(ConfigFile::GetInstance().GetValue("DNN_VM_AICPU_ASCENDGraphOptimizer", optimizerConfig));
    ASSERT_EQ(optimizerConfig, "AICPUOptimizer");

    map<string, OpsKernelInfoStorePtr> opsKernelInfoStores;
    GetOpsKernelInfoStores(opsKernelInfoStores);
    // operator[] inserts a null entry if the store is missing, so the null check
    // below covers both "not registered" and "registered as null".
    const OpsKernelInfoStorePtr &store = opsKernelInfoStores["aicpu_ascend_kernel"];
    ASSERT_NE(store, nullptr);
    ASSERT_EQ(store->Initialize(options), SUCCESS);
}

// Checks that the "aicpu_ascend_kernel" store reports a non-empty set of op
// kernel infos.
// NOTE(review): this appears to rely on engine Initialize() having been called
// by an earlier test in the same process; confirm test-order independence.
TEST(AicpuKernelInfo, GetAllOpsKernelInfo_SUCCESS)
{
    map<string, OpsKernelInfoStorePtr> opsKernelInfoStores;
    GetOpsKernelInfoStores(opsKernelInfoStores);
    // Guard the dereference: a missing store would otherwise crash the test
    // binary instead of producing an assertion failure.
    const OpsKernelInfoStorePtr &store = opsKernelInfoStores["aicpu_ascend_kernel"];
    ASSERT_NE(store, nullptr);

    map<string, OpInfo> infos;
    ASSERT_TRUE(infos.empty());
    store->GetAllOpsKernelInfo(infos);
    ASSERT_FALSE(infos.empty());
}

// Checks that an "Add" op with two DT_INT32 NCHW inputs and one matching output
// (shape {1,1,3,1}) is reported as supported by the "aicpu_ascend_kernel" store.
// Also checks that no unsupported reason is set.
TEST(AicpuKernelInfo, CheckSupported_SUCCESS)
{
    map<string, OpsKernelInfoStorePtr> opsKernelInfoStores;
    GetOpsKernelInfoStores(opsKernelInfoStores);
    // Guard the dereference so a missing store fails the test instead of crashing.
    const OpsKernelInfoStorePtr &store = opsKernelInfoStores["aicpu_ascend_kernel"];
    ASSERT_NE(store, nullptr);

    OpDescPtr opDescPtr = make_shared<OpDesc>("Add", "Add");
    vector<int64_t> tensorShape = {1, 1, 3, 1};
    GeTensorDesc tensor1(GeShape(tensorShape), FORMAT_NCHW, DT_INT32);
    opDescPtr->AddInputDesc("x", tensor1);
    opDescPtr->AddInputDesc("y", tensor1);
    opDescPtr->AddOutputDesc("z", tensor1);

    string unSupportedReason;
    ASSERT_TRUE(store->CheckSupported(opDescPtr, unSupportedReason));
    ASSERT_TRUE(unSupportedReason.empty());
}

// Checks that opsFlagCheck leaves the output flag empty for an "Add" node.
// The node is added to a fresh ComputeGraph.
TEST(AicpuKernelInfo, opsFlagCheck_SUCCESS)
{
    map<string, OpsKernelInfoStorePtr> opsKernelInfoStores;
    GetOpsKernelInfoStores(opsKernelInfoStores);
    // Guard the dereference so a missing store fails the test instead of crashing.
    const OpsKernelInfoStorePtr &store = opsKernelInfoStores["aicpu_ascend_kernel"];
    ASSERT_NE(store, nullptr);

    OpDescPtr opDescPtr = make_shared<OpDesc>("Add", "Add");
    vector<int64_t> tensorShape = {1, 1, 3, 1};
    GeTensorDesc tensor1(GeShape(tensorShape), FORMAT_NCHW, DT_INT32);
    opDescPtr->AddInputDesc("x", tensor1);
    opDescPtr->AddInputDesc("y", tensor1);
    opDescPtr->AddOutputDesc("z", tensor1);

    shared_ptr<ComputeGraph> graphPtr = make_shared<ComputeGraph>("test_graph");
    // Check the node before dereferencing it; the original passed *AddNode(...)
    // directly, which would crash on a null result.
    auto node = graphPtr->AddNode(opDescPtr);
    ASSERT_NE(node, nullptr);

    string opsFlag;
    store->opsFlagCheck(*node, opsFlag);
    ASSERT_TRUE(opsFlag.empty());
}
